{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Whisper and Distil-Whisper model evaluation\n",
    "\n",
    "Here we perform evaluation of Whisper and Distil-Whisper speech-to-text models on a Common Voice 13 dataset.\n",
    "\n",
    "In order to run this notebook, first the [267-distil-whisper-asr](../../../notebooks/267-distil-whisper-asr/267-distil-whisper-asr.ipynb) notebook needs to be run to export `openai/whisper-large-v2` and `distil-whisper/distil-large-v2` models."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "from openvino import Core\n",
    "\n",
    "core = Core()\n",
    "\n",
    "device = widgets.Dropdown(\n",
    "    options=core.available_devices + [\"AUTO\"],\n",
    "    value=\"CPU\",\n",
    "    description=\"Device:\",\n",
    "    disabled=False,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-25T16:01:57.368695900Z",
     "start_time": "2023-11-25T16:01:56.959444300Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:01:57.370189Z",
     "start_time": "2023-11-25T16:01:57.368695900Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "SAMPLING_RATE = 16000\n",
    "\n",
    "AVAILABLE_MODELS = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:01:57.373198100Z",
     "start_time": "2023-11-25T16:01:57.370189Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "daaf279c2b3e41b583ab74c01836b6e2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Checkbox(value=True, description='Load Whisper large-v2 PyTorch')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b45b06e98822427ebc7310b3fd7d77de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Checkbox(value=True, description='Load Whisper large-v2 OpenVINO')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0e893199bce04655be75ad6b705e4493",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Checkbox(value=True, description='Load Whisper large-v2 OpenVINO Quantized')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c6e5c6d529a245f38d04ab969a5aa83b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Checkbox(value=True, description='Load Distil-Whisper large-v2 PyTorch')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8be299c8f7e4dc79af042902c27bebe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Checkbox(value=True, description='Load Distil-Whisper large-v2 OpenVINO')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "649d5b25628c4534bdc1ba7fb3af65f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Checkbox(value=True, description='Load Distil-Whisper large-v2 OpenVINO Quantized')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "load_whisper_model_pt = widgets.Checkbox(value=True, description=f\"Load Whisper large-v2 PyTorch\", disabled=False)\n",
    "load_whisper_model_ov = widgets.Checkbox(value=True, description=f\"Load Whisper large-v2 OpenVINO\", disabled=False)\n",
    "load_whisper_model_ov_int8 = widgets.Checkbox(value=True, description=f\"Load Whisper large-v2 OpenVINO Quantized\", disabled=False)\n",
    "\n",
    "load_distil_whisper_model_pt = widgets.Checkbox(value=True, description=\"Load Distil-Whisper large-v2 PyTorch\", disabled=False)\n",
    "load_distil_whisper_model_ov = widgets.Checkbox(value=True, description=\"Load Distil-Whisper large-v2 OpenVINO\", disabled=False)\n",
    "load_distil_whisper_model_ov_int8 = widgets.Checkbox(value=True, description=\"Load Distil-Whisper large-v2 OpenVINO Quantized\", disabled=False)\n",
    "\n",
    "display(load_whisper_model_pt)\n",
    "display(load_whisper_model_ov)\n",
    "display(load_whisper_model_ov_int8)\n",
    "display(load_distil_whisper_model_pt)\n",
    "display(load_distil_whisper_model_ov)\n",
    "display(load_distil_whisper_model_ov_int8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Whisper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:02:13.484031600Z",
     "start_time": "2023-11-25T16:01:57.413211800Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq\n",
    "\n",
    "# PyTorch Whisper\n",
    "\n",
    "whisper_model_id = \"openai/whisper-large-v2\"\n",
    "\n",
    "if load_whisper_model_pt.value or load_whisper_model_ov.value or load_whisper_model_ov_int8.value:\n",
    "    whisper_processor = AutoProcessor.from_pretrained(whisper_model_id)\n",
    "\n",
    "    if load_whisper_model_pt.value:\n",
    "        whisper_model_pt = AutoModelForSpeechSeq2Seq.from_pretrained(whisper_model_id).eval()\n",
    "        AVAILABLE_MODELS.append((whisper_model_pt, whisper_processor, f\"Whisper large-v2 PyTorch\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:02:32.634323800Z",
     "start_time": "2023-11-25T16:02:13.486032900Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda-11.7'\n",
      "Compiling the encoder to CPU ...\n",
      "Compiling the decoder to CPU ...\n",
      "Compiling the decoder to CPU ...\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "from optimum.intel.openvino import OVModelForSpeechSeq2Seq\n",
    "\n",
    "# OV FP32 Whisper\n",
    "\n",
    "DISTIL_WHISPER_DIR = Path(\"../../../notebooks/267-distil-whisper-asr\")\n",
    "WHISPER_OV_PATH = DISTIL_WHISPER_DIR / whisper_model_id.replace(\"/\", \"_\")\n",
    "\n",
    "\n",
    "if load_whisper_model_ov.value:\n",
    "    whisper_model_ov_fp32 = OVModelForSpeechSeq2Seq.from_pretrained(WHISPER_OV_PATH, compile=False).to(device.value)\n",
    "    whisper_model_ov_fp32.compile()\n",
    "    AVAILABLE_MODELS.append((whisper_model_ov_fp32, whisper_processor, f\"Whisper large-v2 OpenVINO\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:02:54.560878400Z",
     "start_time": "2023-11-25T16:02:32.735626100Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Compiling the encoder to CPU ...\n",
      "Compiling the decoder to CPU ...\n",
      "Compiling the decoder to CPU ...\n"
     ]
    }
   ],
   "source": [
    "# OV INT8 Whisper\n",
    "\n",
    "WHISPER_OV_INT8_PATH = Path(f\"{WHISPER_OV_PATH}_quantized\")\n",
    "\n",
    "if load_whisper_model_ov_int8.value:\n",
    "    whisper_model_ov_int8 = OVModelForSpeechSeq2Seq.from_pretrained(WHISPER_OV_INT8_PATH, compile=False).to(device.value)\n",
    "    whisper_model_ov_int8.compile()\n",
    "    AVAILABLE_MODELS.append((whisper_model_ov_int8, whisper_processor, f\"Whisper large-v2 OpenVINO Quantized\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:02:54.576332500Z",
     "start_time": "2023-11-25T16:02:54.576332500Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def recognize_audio(model, processor, audio):\n",
    "    input_features = processor(\n",
    "        audio,\n",
    "        sampling_rate=SAMPLING_RATE,\n",
    "        return_tensors=\"pt\",\n",
    "    ).input_features\n",
    "    predicted_ids = model.generate(input_features)\n",
    "    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]\n",
    "\n",
    "    return transcription"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Distil Whisper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:03:01.059721900Z",
     "start_time": "2023-11-25T16:02:54.576332500Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq\n",
    "\n",
    "# PyTorch Distil Whisper\n",
    "\n",
    "distil_whisper_model_id = \"distil-whisper/distil-large-v2\"\n",
    "\n",
    "if load_distil_whisper_model_pt.value or load_distil_whisper_model_ov.value or load_distil_whisper_model_ov_int8.value:\n",
    "    distil_whisper_processor = AutoProcessor.from_pretrained(distil_whisper_model_id)\n",
    "\n",
    "    if load_distil_whisper_model_pt.value:\n",
    "        distil_whisper_model_pt = AutoModelForSpeechSeq2Seq.from_pretrained(distil_whisper_model_id).eval()\n",
    "        AVAILABLE_MODELS.append((distil_whisper_model_pt, distil_whisper_processor, \"Distil-Whisper large-v2 PyTorch\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:03:05.207141100Z",
     "start_time": "2023-11-25T16:03:01.065701400Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Compiling the encoder to CPU ...\n",
      "Compiling the decoder to CPU ...\n",
      "Compiling the decoder to CPU ...\n"
     ]
    }
   ],
   "source": [
    "from optimum.intel.openvino import OVModelForSpeechSeq2Seq\n",
    "\n",
    "# OV FP32 Distil Whisper\n",
    "\n",
    "DISTIL_WHISPER_OV_PATH = DISTIL_WHISPER_DIR / distil_whisper_model_id.replace(\"/\", \"_\")\n",
    "\n",
    "if load_distil_whisper_model_ov.value:\n",
    "    distil_whisper_model_ov_fp32 = OVModelForSpeechSeq2Seq.from_pretrained(DISTIL_WHISPER_OV_PATH, compile=False).to(device.value)\n",
    "    distil_whisper_model_ov_fp32.compile()\n",
    "    AVAILABLE_MODELS.append((distil_whisper_model_ov_fp32, distil_whisper_processor, \"Distil-Whisper large-v2 OpenVINO\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-25T16:03:06.946384900Z",
     "start_time": "2023-11-25T16:03:05.218404900Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Compiling the encoder to CPU ...\n",
      "Compiling the decoder to CPU ...\n",
      "Compiling the decoder to CPU ...\n"
     ]
    }
   ],
   "source": [
    "# OV INT8 Distil Whisper\n",
    "\n",
    "DISTIL_WHISPER_OV_INT8_PATH = Path(f\"{DISTIL_WHISPER_OV_PATH}_quantized\")\n",
    "\n",
    "if load_distil_whisper_model_ov_int8.value:\n",
    "    distil_whisper_model_ov_int8 = OVModelForSpeechSeq2Seq.from_pretrained(DISTIL_WHISPER_OV_INT8_PATH, compile=False).to(device.value)\n",
    "    distil_whisper_model_ov_int8.compile()\n",
    "    AVAILABLE_MODELS.append((distil_whisper_model_ov_int8, distil_whisper_processor, \"Distil-Whisper large-v2 OpenVINO Quantized\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Prepare test dataset\n",
    "\n",
    "In order not to download all splits for all languages of `common_voice_13` below we cache only the needed part of `en/test` split. This is performed only once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n",
      "Token is valid (permission: read).\n",
      "Your token has been saved to /home/nsavel/.cache/huggingface/token\n",
      "Login successful\n"
     ]
    }
   ],
   "source": [
    "from huggingface_hub import login\n",
    "import os\n",
    "\n",
    "if \"HF_TOKEN\" in os.environ:\n",
    "    login(os.environ[\"HF_TOKEN\"])\n",
    "else:\n",
    "    login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/datasets/load.py:1429: FutureWarning: The repository for mozilla-foundation/common_voice_13_0 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/mozilla-foundation/common_voice_13_0\n",
      "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
      "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "def resample(audio, src_sample_rate, dst_sample_rate):\n",
    "    \"\"\"\n",
    "    Resample audio to specific sample rate\n",
    "\n",
    "    Parameters:\n",
    "      audio: input audio signal\n",
    "      src_sample_rate: source audio sample rate\n",
    "      dst_sample_rate: destination audio sample rate\n",
    "    Returns:\n",
    "      resampled_audio: input audio signal resampled with dst_sample_rate\n",
    "    \"\"\"\n",
    "    if src_sample_rate == dst_sample_rate:\n",
    "        return audio\n",
    "    duration = audio.shape[0] / src_sample_rate\n",
    "    resampled_data = np.zeros(shape=(int(duration * dst_sample_rate)), dtype=np.float32)\n",
    "    x_old = np.linspace(0, duration, audio.shape[0], dtype=np.float32)\n",
    "    x_new = np.linspace(0, duration, resampled_data.shape[0], dtype=np.float32)\n",
    "    resampled_audio = np.interp(x_new, x_old, audio)\n",
    "    return resampled_audio.astype(np.float32)\n",
    "\n",
    "\n",
    "def save_parquet(df: pd.DataFrame, save_filepath: Path):\n",
    "    if save_filepath.exists():\n",
    "        df.to_parquet(save_filepath, engine=\"fastparquet\", append=True)\n",
    "    else:\n",
    "        df.to_parquet(save_filepath, engine=\"fastparquet\")\n",
    "\n",
    "\n",
    "TEST_DATASET_SIZE = 16372\n",
    "chunk_size = 1000\n",
    "test_dataset = load_dataset(\"mozilla-foundation/common_voice_13_0\", \"en\", split=\"test\", streaming=True)\n",
    "dataset_filepath = Path(\"common_voice_13_0_en_test.parquet\")\n",
    "\n",
    "if not dataset_filepath.exists():\n",
    "    columns = [\"audio_bytes\", \"transcription\"]\n",
    "    df = pd.DataFrame(columns=columns)\n",
    "    total_size = 0\n",
    "    for i, data_item in tqdm(enumerate(test_dataset), total=TEST_DATASET_SIZE, desc=\"Preparing test dataset\"):\n",
    "        audio_data = data_item[\"audio\"]\n",
    "        audio = resample(audio_data[\"array\"].astype(np.float32), audio_data[\"sampling_rate\"], SAMPLING_RATE)\n",
    "        df.loc[len(df)] = (audio.tobytes(), data_item[\"sentence\"])\n",
    "        if len(df) == chunk_size:\n",
    "            total_size += chunk_size\n",
    "            save_parquet(df, dataset_filepath)\n",
    "            df = pd.DataFrame(columns=columns)\n",
    "    if len(df) > 0:\n",
    "        total_size += len(df)\n",
    "        save_parquet(df, dataset_filepath)\n",
    "    assert total_size == TEST_DATASET_SIZE, f\"Acquired dataset size does not equal the expected size: {total_size} {TEST_DATASET_SIZE}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import pyarrow.parquet as pq\n",
    "import time\n",
    "from datetime import datetime\n",
    "from itertools import islice\n",
    "from jiwer import wer, wer_standardize\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "\n",
    "def calculate_transcription_time_and_accuracy(model, processor, dataset_size, model_name):\n",
    "    ground_truths = []\n",
    "    predictions = []\n",
    "    total_ground_truth_duration = 0\n",
    "    total_prediction_time = 0\n",
    "    parquet_file = pq.ParquetFile(dataset_filepath)\n",
    "    dataset_size = TEST_DATASET_SIZE if dataset_size == -1 else dataset_size\n",
    "    for batch in tqdm(islice(parquet_file.iter_batches(batch_size=1), dataset_size), desc=model_name, total=dataset_size):\n",
    "        data_item = batch.to_pandas().iloc[0].to_dict()\n",
    "        audio = np.frombuffer(data_item[\"audio_bytes\"], dtype=np.float32)\n",
    "        ground_truth = data_item[\"transcription\"]\n",
    "\n",
    "        audio_duration = audio.shape[0] / SAMPLING_RATE\n",
    "        total_ground_truth_duration += audio_duration\n",
    "\n",
    "        start_time = time.perf_counter()\n",
    "        transcription = recognize_audio(model, processor, audio)\n",
    "        end_time = time.perf_counter()\n",
    "        total_prediction_time += end_time - start_time\n",
    "\n",
    "        ground_truths.append(ground_truth)\n",
    "        predictions.append(transcription)\n",
    "\n",
    "    word_accuracy = (1 - wer(ground_truths, predictions, reference_transform=wer_standardize, hypothesis_transform=wer_standardize)) * 100\n",
    "    mean_performace_coefficient = total_ground_truth_duration / total_prediction_time\n",
    "    return word_accuracy, mean_performace_coefficient\n",
    "\n",
    "\n",
    "def evaluate(dataset_size, print_intermediate_results=False):\n",
    "    save_dir = Path(f\"results/eval_{dataset_size}_{datetime.now().replace(microsecond=0)}\")\n",
    "    save_dir.mkdir(parents=True, exist_ok=True)\n",
    "    log_file = open(save_dir / \"log.txt\", \"a\")\n",
    "\n",
    "    df = pd.DataFrame(columns=[\"Model\", \"Accuracy\", \"Real-time Performance\"])\n",
    "    for model, processor, model_name in AVAILABLE_MODELS:\n",
    "        accuracy, mean_performace_coefficient = calculate_transcription_time_and_accuracy(model, processor, dataset_size, model_name)\n",
    "        df.loc[len(df)] = (model_name, f\"{accuracy:.2f}%\", f\"{mean_performace_coefficient:.2f}x\")\n",
    "        if print_intermediate_results:\n",
    "            log_msg = f\"{model_name}: {accuracy:.2f}% {mean_performace_coefficient:.2f}x\"\n",
    "            print(log_msg)\n",
    "            log_file.write(log_msg + \"\\n\")\n",
    "            log_file.flush()\n",
    "\n",
    "    formatters = {}\n",
    "    for col in df.select_dtypes(\"object\"):\n",
    "        len_max = df[col].str.len().max()\n",
    "        formatters[col] = eval('lambda x: f\"{x:<{' + str(len_max) + '}s}\"')\n",
    "    df.to_csv(save_dir / f\"result.csv\", index=False)\n",
    "    log_msg = df.to_string(index=False, formatters=formatters, justify=\"left\")\n",
    "    print(log_msg)\n",
    "    log_file.write(log_msg + \"\\n\")\n",
    "    log_file.flush()\n",
    "\n",
    "\n",
    "evaluate(dataset_size=1)\n",
    "evaluate(dataset_size=10)\n",
    "evaluate(dataset_size=100)\n",
    "evaluate(dataset_size=-1, print_intermediate_results=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
